#include "fintrf.h"
!     
!=======================================================================
! Gateway subroutine
subroutine mexfunction(nlhs, plhs, nrhs, prhs)
! MEX gateway. Validates the 13 MATLAB inputs, converts inputs 9-13 to
! integers, creates the four real double outputs and calls fortranSub.
!
! Expected MATLAB call (argument order as consumed below):
!   [px, pxx, pxxx, p4x] = <mexname>(gx, gxx, gxxx, g4x, hx, hxx, hxxx, h4x, ...
!                                    Nmax, firstOn, order_app, ny, nx)
!   Inputs 1-8  : real double arrays, passed to fortranSub in this order.
!   Inputs 9-13 : real double scalars, converted with convertToInt.
! Output dimensions (column-major):
!   px   ny x (Nmax+firstOn) x nx
!   pxx  ny x (Nmax+firstOn) x nx x nx
!   pxxx ny x (Nmax+firstOn) x nx x nx x nx
!   p4x  ny x (Nmax+firstOn) x nx x nx x nx x nx
! All four outputs are always created, even when nlhs < 4.

! Declarations
implicit none

! mexFunction arguments:
mwPointer plhs(*), prhs(*)
integer nlhs, nrhs

! MATLAB API function declarations:
mwPointer mxGetPr
mwPointer mxCreateNumericArray
mwPointer mxGetNumberOfElements
integer*4 mxIsDouble, mxIsComplex
integer*4 mxClassIDFromClassName

! Pointers to input/output data:
mwPointer pr1 ,pr2 ,pr3 ,pr4 ,pr5 ,pr6 ,pr7 ,pr8 ,pr9 ,pr10,&
          pr11,pr12,pr13,&
          pr1_out,pr2_out,pr3_out,pr4_out

! Options read from inputs 9-13:
integer:: nx,ny,Nmax,firstOn,order_app
integer convertToInt

! Loop index over the inputs
integer k

! Minimum number of elements required of each of inputs 1-8
mwSize need(8)
mwSize nx8, ny8

! Arguments for mxCreateNumericArray
integer*4 classid
integer*4 complexflag
mwSize ndim
mwSize dims(6)

!-----------------------------------------------------------------------
! Check for proper number of arguments.
if (nrhs .ne. 13) then
   call mexErrMsgIdAndTxt ('MATLAB:dblmat:nInput','13 input required.')
end if
if (nlhs .gt. 4) then
   call mexErrMsgIdAndTxt ('MATLAB:dblmat:nOutput','Too many output arguments.')
end if

! All data is read through mxGetPr as REAL*8, so every input must be a real
! double array; anything else would be reinterpreted as raw memory.
do k = 1, 13
   if (mxIsDouble(prhs(k)) .eq. 0 .or. mxIsComplex(prhs(k)) .ne. 0) then
      call mexErrMsgIdAndTxt ('MATLAB:dblmat:notRealDouble', &
                              'All inputs must be real double arrays.')
   end if
end do

! convertToInt dereferences the data pointer of inputs 9-13; an empty
! array has no element to read.
do k = 9, 13
   if (mxGetNumberOfElements(prhs(k)) .lt. 1) then
      call mexErrMsgIdAndTxt ('MATLAB:dblmat:emptyScalar', &
           'Inputs 9-13 (Nmax, firstOn, order_app, ny, nx) must be non-empty.')
   end if
end do

! Data pointers of the inputs.
pr1  = mxGetPr(prhs(1))
pr2  = mxGetPr(prhs(2))
pr3  = mxGetPr(prhs(3))
pr4  = mxGetPr(prhs(4))
pr5  = mxGetPr(prhs(5))
pr6  = mxGetPr(prhs(6))
pr7  = mxGetPr(prhs(7))
pr8  = mxGetPr(prhs(8))
pr9  = mxGetPr(prhs(9))
pr10 = mxGetPr(prhs(10))
pr11 = mxGetPr(prhs(11))
pr12 = mxGetPr(prhs(12))
pr13 = mxGetPr(prhs(13))

! Scalar options (only the first element of inputs 9-13 is read).
Nmax      = convertToInt(%VAL(pr9))
firstOn   = convertToInt(%VAL(pr10))
order_app = convertToInt(%VAL(pr11))
ny        = convertToInt(%VAL(pr12))
nx        = convertToInt(%VAL(pr13))

! Negative values cannot form output dimensions; Nmax < 0 with firstOn = 1
! would make fortranSub write period 1 into an empty dimension.
if (nx .lt. 0 .or. ny .lt. 0 .or. Nmax .lt. 0 .or. Nmax + firstOn .lt. 0) then
   call mexErrMsgIdAndTxt ('MATLAB:dblmat:negativeSize', &
        'ny, nx, Nmax and Nmax+firstOn must be non-negative.')
end if

! fortranSub views inputs 1-8 with explicit shapes built from ny and nx.
! gx..g4x may be read in full for every row (they seed the recursion and,
! with firstOn = 1, are copied into period 1); hx is always read, while
! hxx, hxxx and h4x are only needed for order_app > 1, > 2 and > 3.
ny8 = ny
nx8 = nx
need(1) = ny8*nx8
need(2) = need(1)*nx8
need(3) = need(2)*nx8
need(4) = need(3)*nx8
need(5) = nx8*nx8
need(6) = 0
need(7) = 0
need(8) = 0
if (order_app .gt. 1) need(6) = need(5)*nx8
if (order_app .gt. 2) need(7) = need(5)*nx8*nx8
if (order_app .gt. 3) need(8) = need(5)*nx8*nx8*nx8
do k = 1, 8
   if (mxGetNumberOfElements(prhs(k)) .lt. need(k)) then
      call mexErrMsgIdAndTxt ('MATLAB:dblmat:inputTooSmall', &
           'Inputs 1-8 have fewer elements than ny and nx require.')
   end if
end do

! Create the outputs: ny x (Nmax+firstOn) x nx [x nx [x nx [x nx]]].
classid     = mxClassIDFromClassName('double')
complexflag = 0                                 ! 0 for real data
dims(1)   = ny
dims(2)   = Nmax + firstOn
dims(3:6) = nx

ndim = 3
plhs(1) = mxCreateNumericArray(ndim, dims, classid, complexflag)
pr1_out = mxGetPr(plhs(1))

ndim = 4
plhs(2) = mxCreateNumericArray(ndim, dims, classid, complexflag)
pr2_out = mxGetPr(plhs(2))

ndim = 5
plhs(3) = mxCreateNumericArray(ndim, dims, classid, complexflag)
pr3_out = mxGetPr(plhs(3))

ndim = 6
plhs(4) = mxCreateNumericArray(ndim, dims, classid, complexflag)
pr4_out = mxGetPr(plhs(4))

! Call the computational routine.
call fortranSub(%VAL(pr1_out),%VAL(pr2_out),%VAL(pr3_out),%VAL(pr4_out),&
   %VAL(pr1) ,%VAL(pr2) ,%VAL(pr3) ,%VAL(pr4) ,%VAL(pr5) ,&
   %VAL(pr6) ,%VAL(pr7) ,%VAL(pr8) ,Nmax      ,firstOn   ,&
   order_app,ny,nx)
return
end subroutine mexfunction

function convertToInt(matlab_dble) result( int_val )
    ! Convert a MATLAB double that carries an integer-valued option (a size,
    ! a period count or a flag) to a default INTEGER.
    !
    ! NINT rounds to the nearest integer, so a value that is integral only up
    ! to floating-point round-off (e.g. 2.9999999999999996 from a computed
    ! expression) maps to 3; INT would truncate it to 2.
    ! NOTE(review): values outside the default-integer range are not checked.

    implicit none

    REAL*8, INTENT(IN):: matlab_dble
    INTEGER:: int_val

    int_val = NINT(matlab_dble)

end function convertToInt


!===============================================================
! This subroutine computes the conditional moments up to k_period periods into
! the future for the variables in the g-function.
! In our notation we solve p(x,sig) = E_t[r(x_{t+1},sig)].
! The computations are derived using the Perturbation On Perturbation (POP) method.
! IMPORTANT: 1) This is for a "level" approximation under perfect foresight.
!            2) If firstOn = 1, the first time period in px, pxx, pxxx, p4x is
!               the current period, i.e. we reproduce gx, gxx, gxxx, g4x.
!
! Computational subroutine
subroutine fortranSub(px,pxx,pxxx,p4x,&
 gx,gxx,gxxx,g4x,hx,hxx,hxxx,h4x,k_period,firstOn,order_app,ny,nx)
! Propagates the derivatives of the g-function forward through h.
! Starting from r = g, every period j computes the derivatives of
! p(x) = r(h(x)) by the chain rule from the derivatives of r (rx..r4x) and
! of h (hx..h4x); p then serves as r for period j+1. Rows i = 1..ny of g
! are treated independently.
!
! Arguments
!   px, pxx, pxxx, p4x (out) : derivatives of order 1-4, indexed
!                              (row, period, x-index, ...). Zeroed on entry,
!                              so orders above order_app come back as 0.
!   gx, gxx, gxxx, g4x (in)  : derivatives of g, indexed (row, x-index, ...).
!   hx, hxx, hxxx, h4x (in)  : derivatives of h, indexed (h-component,
!                              x-index, ...), e.g. px(i,j,a) = sum_k rx(k)*hx(k,a).
!   k_period  (in) : number of propagated periods after the optional copy.
!   firstOn   (in) : if 1, period 1 is a copy of g and periods 2..k_period+1
!                    are propagated; otherwise periods 1..k_period+firstOn
!                    are all propagated.
!   order_app (in) : highest order computed (first order is always
!                    computed; orders above 4 are not implemented).
!   ny, nx    (in) : number of rows of g and number of x-indices.

IMPLICIT NONE
INTEGER, INTENT(IN) :: k_period,firstOn,order_app,ny,nx
REAL*8,  INTENT(OUT):: px(ny,k_period+firstOn,nx),pxx(ny,k_period+firstOn,nx,nx),&
                       pxxx(ny,k_period+firstOn,nx,nx,nx),p4x(ny,k_period+firstOn,nx,nx,nx,nx)
REAL*8,  INTENT(IN) :: gx(ny,nx),gxx(ny,nx,nx),gxxx(ny,nx,nx,nx),g4x(ny,nx,nx,nx,nx),&
                       hx(nx,nx),hxx(nx,nx,nx),hxxx(nx,nx,nx,nx),h4x(nx,nx,nx,nx,nx)

! Loop indices: i = row of g, j = period, alfa* = x-indices of the result,
! gama* = summation indices of the chain rule.
INTEGER i,j,alfa1,alfa2,alfa3,alfa4,gama3,gama4,startIndex
! Derivatives of the function being propagated (row i, previous period)
! and the 1x1 MATMUL results of the individual chain-rule terms.
REAL*8  rx(1,nx),rxx(nx,nx),tmp(1,1),&
        tmp_1(1,1),tmp_2(1,1),tmp_3(1,1),tmp_4(1,1),tmp_5(1,1),tmp_6(1,1),&
        tmp_7(1,1),tmp_8(1,1),tmp_9(1,1),tmp_10(1,1),tmp_11(1,1),tmp_12(1,1),&
        tmp_13(1,1),tmp_14(1,1),tmp_15(1,1)
! nx**3 and nx**4 work arrays: heap-allocated, and only when order_app
! needs them, instead of automatic (stack) arrays that grow with nx**4.
REAL*8, ALLOCATABLE :: rxxx(:,:,:), r4x(:,:,:,:)

! INTENT(OUT) arrays are undefined on entry: zero them so every element,
! including the orders above order_app, is defined on return.
px   = 0._8
pxx  = 0._8
pxxx = 0._8
p4x  = 0._8

IF (order_app > 2) ALLOCATE(rxxx(nx,nx,nx))
IF (order_app > 3) ALLOCATE(r4x(nx,nx,nx,nx))

! The first time period: with firstOn = 1 it is g itself.
startIndex = 1
IF (firstOn == 1) THEN
   startIndex = 2
   IF (k_period >= 0) THEN   ! otherwise the period dimension is empty
      px(1:ny,1,1:nx)                  = gx
      pxx(1:ny,1,1:nx,1:nx)            = gxx
      pxxx(1:ny,1,1:nx,1:nx,1:nx)      = gxxx
      p4x(1:ny,1,1:nx,1:nx,1:nx,1:nx)  = g4x
   END IF
END IF

DO i=1,ny
   ! Seed the recursion with row i of g (only the orders that are used).
   rx = gx(i:i,1:nx)
   IF (order_app > 1) rxx  = gxx(i,1:nx,1:nx)
   IF (order_app > 2) rxxx = gxxx(i,1:nx,1:nx,1:nx)
   IF (order_app > 3) r4x  = g4x(i,1:nx,1:nx,1:nx,1:nx)

   DO j=startIndex,k_period+firstOn
      ! ************** The first order effects *****************
      px(i:i,j,:) = MATMUL(rx(1:1,1:nx),hx)

      !*************** The second order effects ***************
      IF (order_app > 1) THEN
         DO alfa1=1,nx
            DO alfa2=alfa1,nx
               pxx(i,j,alfa1:alfa1,alfa2:alfa2) = MATMUL(TRANSPOSE(hx(1:nx,alfa1:alfa1)),MATMUL(rxx,hx(1:nx,alfa2:alfa2)))&
                                      + MATMUL(rx(1:1,1:nx),hxx(1:nx,alfa1:alfa1,alfa2))
            END DO
            ! Fill row alfa1 left of the diagonal from the symmetric entries
            ! already computed for the earlier values of alfa1.
            IF (alfa1 > 1) THEN
               pxx(i,j,alfa1:alfa1,1:alfa1-1) = TRANSPOSE(pxx(i,j,1:alfa1-1,alfa1:alfa1))
            END IF
         END DO
      END IF

      ! ***************** The third order effects *****************
      ! Only alfa1 <= alfa2 <= alfa3 is computed; the other permutations
      ! are copied from it below.
      IF (order_app > 2) THEN
         DO alfa1=1,nx
            DO alfa2=alfa1,nx
               DO alfa3=alfa2,nx
                  tmp = 0._8
                  DO gama3=1,nx
                     tmp = tmp + MATMUL(TRANSPOSE(hx(1:nx,alfa1:alfa1)), &
                                        MATMUL(rxxx(1:nx,1:nx,gama3),hx(1:nx,alfa2:alfa2)))*hx(gama3,alfa3)
                  END DO
                  pxxx(i,j,alfa1,alfa2:alfa2,alfa3:alfa3) = tmp &
                            +MATMUL(TRANSPOSE(hx(1:nx,alfa1:alfa1)),MATMUL(rxx,hxx(1:nx,alfa2:alfa2,alfa3)))&
                            +MATMUL(TRANSPOSE(hxx(1:nx,alfa1:alfa1,alfa3)),MATMUL(rxx,hx(1:nx,alfa2:alfa2)))&
                            +MATMUL(TRANSPOSE(hxx(1:nx,alfa1:alfa1,alfa2)),MATMUL(rxx,hx(1:nx,alfa3:alfa3)))&
                            +MATMUL(rx(1:1,1:nx),hxxx(1:nx,alfa1:alfa1,alfa2,alfa3))

                  ! Using symmetry for alfa1 and alfa2
                  IF (alfa1 == alfa2 .AND. alfa2 .NE. alfa3) THEN  !alfa1==alfa2~=alfa3
                     pxxx(i,j,alfa1,alfa3,alfa1) = pxxx(i,j,alfa1,alfa2,alfa3)
                     pxxx(i,j,alfa3,alfa1,alfa1) = pxxx(i,j,alfa1,alfa2,alfa3)
                  END IF
                  ! Using symmetry for alfa2 and alfa3
                  IF (alfa1 .NE. alfa2 .AND. alfa2 == alfa3) THEN   !alfa1~=alfa2==alfa3
                     pxxx(i,j,alfa2,alfa1,alfa2) = pxxx(i,j,alfa1,alfa2,alfa3)
                     pxxx(i,j,alfa2,alfa2,alfa1) = pxxx(i,j,alfa1,alfa2,alfa3)
                  END IF
                  ! Using symmetry for alfa1, alfa2, and alfa3 (all distinct)
                  IF (alfa1 .NE. alfa2 .AND. alfa1 .NE. alfa3 .AND. alfa2 .NE. alfa3) THEN
                     pxxx(i,j,alfa1,alfa3,alfa2) = pxxx(i,j,alfa1,alfa2,alfa3)
                     pxxx(i,j,alfa3,alfa1,alfa2) = pxxx(i,j,alfa1,alfa2,alfa3)
                     pxxx(i,j,alfa3,alfa2,alfa1) = pxxx(i,j,alfa1,alfa2,alfa3)
                     pxxx(i,j,alfa2,alfa3,alfa1) = pxxx(i,j,alfa1,alfa2,alfa3)
                     pxxx(i,j,alfa2,alfa1,alfa3) = pxxx(i,j,alfa1,alfa2,alfa3)
                  END IF
               END DO
            END DO
         END DO
      END IF

      ! ***************** The fourth order effects *****************
      ! The 15 terms are the x(alfa4)-derivative of the five third-order
      ! terms; every index combination is computed (no symmetry used).
      IF (order_app > 3) THEN
         DO alfa1=1,nx
            DO alfa2=1,nx
               DO alfa3=1,nx
                  DO alfa4=1,nx
                     tmp_1 = 0._8
                     DO gama3=1,nx
                        DO gama4=1,nx
                           tmp_1 = tmp_1 + MATMUL(TRANSPOSE(hx(1:nx,alfa1:alfa1)), &
                                                  MATMUL(r4x(1:nx,1:nx,gama3,gama4),hx(1:nx,alfa2:alfa2)))*&
                                           hx(gama3,alfa3)*hx(gama4,alfa4)
                        END DO
                     END DO
                     tmp_2 = 0._8
                     tmp_3 = 0._8
                     tmp_4 = 0._8
                     DO gama3=1,nx
                        tmp_2(1:1,1:1) = tmp_2(1:1,1:1) + MATMUL(TRANSPOSE(hx(1:nx,alfa1:alfa1)), &
                                         MATMUL(rxxx(1:nx,1:nx,gama3),hx(1:nx,alfa2:alfa2)))*hxx(gama3,alfa3,alfa4)
                        tmp_3(1:1,1:1) = tmp_3(1:1,1:1) + MATMUL(TRANSPOSE(hx(1:nx,alfa1:alfa1)), &
                                         MATMUL(rxxx(1:nx,1:nx,gama3),hxx(1:nx,alfa2:alfa2,alfa4)))*hx(gama3,alfa3)
                        tmp_4(1:1,1:1) = tmp_4(1:1,1:1) + MATMUL(TRANSPOSE(hxx(1:nx,alfa1:alfa1,alfa4)), &
                                         MATMUL(rxxx(1:nx,1:nx,gama3),hx(1:nx,alfa2:alfa2)))*hx(gama3,alfa3)
                     END DO
                     tmp_5 = 0._8
                     tmp_8 = 0._8
                     tmp_11= 0._8
                     DO gama4=1,nx
                        tmp_5 = tmp_5 + MATMUL(TRANSPOSE(hx(:,alfa1:alfa1)), &
                                        MATMUL(rxxx(:,:,gama4),hxx(:,alfa2:alfa2,alfa3)))*hx(gama4,alfa4)
                        tmp_8 = tmp_8 + MATMUL(TRANSPOSE(hxx(:,alfa1:alfa1,alfa3)), &
                                        MATMUL(rxxx(:,:,gama4),hx(:,alfa2:alfa2)))*hx(gama4,alfa4)
                        tmp_11= tmp_11+ MATMUL(TRANSPOSE(hxx(:,alfa1:alfa1,alfa2)), &
                                        MATMUL(rxxx(:,:,gama4),hx(:,alfa3:alfa3)))*hx(gama4,alfa4)
                     END DO
                     tmp_6 = MATMUL(TRANSPOSE(hx(1:nx,alfa1:alfa1)),MATMUL(rxx,hxxx(1:nx,alfa2:alfa2,alfa3,alfa4)))
                     tmp_7 = MATMUL(TRANSPOSE(hxx(1:nx,alfa1:alfa1,alfa4)),MATMUL(rxx,hxx(1:nx,alfa2:alfa2,alfa3)))
                     tmp_9 = MATMUL(TRANSPOSE(hxx(:,alfa1:alfa1,alfa3)),MATMUL(rxx,hxx(:,alfa2:alfa2,alfa4)))
                     tmp_10= MATMUL(TRANSPOSE(hxxx(:,alfa1:alfa1,alfa3,alfa4)),MATMUL(rxx,hx(:,alfa2:alfa2)))
                     tmp_12= MATMUL(TRANSPOSE(hxx(:,alfa1:alfa1,alfa2)),MATMUL(rxx,hxx(:,alfa3:alfa3,alfa4)))
                     tmp_13= MATMUL(TRANSPOSE(hxxx(:,alfa1:alfa1,alfa2,alfa4)),MATMUL(rxx,hx(:,alfa3:alfa3)))
                     tmp_14= MATMUL(TRANSPOSE(hxxx(:,alfa1:alfa1,alfa2,alfa3)),MATMUL(rxx,hx(:,alfa4:alfa4)))
                     tmp_15= MATMUL(rx(1:1,1:nx),h4x(1:nx,alfa1:alfa1,alfa2,alfa3,alfa4))

                     p4x(i,j,alfa1,alfa2,alfa3,alfa4)= tmp_1(1,1) + tmp_2(1,1) + tmp_3(1,1) + tmp_4(1,1) + tmp_5(1,1) +&
                                                       tmp_6(1,1) + tmp_7(1,1) + tmp_8(1,1) + tmp_9(1,1) + tmp_10(1,1)+&
                                                       tmp_11(1,1)+ tmp_12(1,1)+ tmp_13(1,1)+ tmp_14(1,1)+ tmp_15(1,1)
                  END DO
               END DO
            END DO
         END DO
      END IF

      ! Period j becomes the function propagated into period j+1.
      rx(1,:) = px(i,j,1:nx)
      IF (order_app > 1) rxx  = pxx(i,j,1:nx,1:nx)
      IF (order_app > 2) rxxx = pxxx(i,j,1:nx,1:nx,1:nx)
      IF (order_app > 3) r4x  = p4x(i,j,1:nx,1:nx,1:nx,1:nx)
   END DO
END DO

END SUBROUTINE fortranSub


